{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d4c33a1-a009-4d41-aa04-f100d012f6ce",
   "metadata": {},
   "source": [
    "# Scalable late interaction vectors in Elasticsearch: Average Vectors #\n",
    "\n",
    "In this notebook, we will be looking at how scale search with late interaction models. We will be taking the average vector over our late interaction multi-vectors and use Elasticsearchs vector search capabilities to achieve scalable search over billions of vectors. \n",
    "   \n",
    "This notebook builds on part 1 where we downloaded the images, created ColPali vectors and saved them to disk. Please execute this notebook before trying the techniques in this notebook. \n",
    "\n",
    "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81fec537-e52b-4e03-967f-b227a5349cae",
   "metadata": {},
   "source": [
    "This is the key part of this notebook. We use `to_avg_vector(vectors)` to convert our 2d array into a single vector that holds the \"average meaning\" of the document. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def to_bit_vectors(embeddings: list) -> list:\n",
    "    return [\n",
    "        np.packbits(np.where(np.array(embedding) > 0, 1, 0))\n",
    "        .astype(np.int8)\n",
    "        .tobytes()\n",
    "        .hex()\n",
    "        for embedding in embeddings\n",
    "    ]\n",
    "\n",
    "\n",
    "def to_avg_vector(vectors):\n",
    "    vectors_array = np.array(vectors)\n",
    "\n",
    "    avg_vector = np.mean(vectors_array, axis=0)\n",
    "\n",
    "    norm = np.linalg.norm(avg_vector)\n",
    "    if norm > 0:\n",
    "        normalized_avg_vector = avg_vector / norm\n",
    "    else:\n",
    "        normalized_avg_vector = avg_vector\n",
    "\n",
    "    return normalized_avg_vector.tolist()"
   ]
  },
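  {
   "cell_type": "markdown",
   "id": "c7a1e0b2-5f3d-4a68-9b1e-2d4f6a8c0e11",
   "metadata": {},
   "source": [
    "As a quick sanity check, here is a minimal sketch that runs `to_avg_vector` from the cell above on a made-up multi-vector. The toy values and the `toy_` names are assumptions for illustration only (real ColPali token vectors have 128 dimensions). It shows that many token vectors collapse into one vector of the same dimensionality, and that the result has unit length. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b2f1c3-6a4e-4b79-8c2f-3e5a7b9d1f22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy multi-vector: 3 token vectors with 4 dimensions each (illustrative values only).\n",
    "toy_vectors = [\n",
    "    [1.0, 0.0, 0.0, 0.0],\n",
    "    [0.0, 1.0, 0.0, 0.0],\n",
    "    [1.0, 1.0, 0.0, 0.0],\n",
    "]\n",
    "\n",
    "toy_avg = to_avg_vector(toy_vectors)\n",
    "\n",
    "# The average has the same dimensionality as a single token vector.\n",
    "print(len(toy_avg))\n",
    "# The result has unit length, which the `dot_product` similarity requires.\n",
    "print(np.isclose(np.linalg.norm(toy_avg), 1.0))"
   ]
  },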
  {
   "cell_type": "markdown",
   "id": "80d0bb03-36e9-4050-83c9-0cb54486842b",
   "metadata": {},
   "source": [
    "Here we are defining our mapping for our Elasticsearch index. Note how we are using the `dense_vector` field type for our average vector.  \n",
    "This allows us to leverage the highly optimized HNSW indexing structures in Elasticsearch with which we can scale our search to billions of documents. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Creating index: searchlabs-colpali-average-vector\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "load_dotenv(\"elastic.env\")\n",
    "\n",
    "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
    "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
    "INDEX_NAME = \"searchlabs-colpali-average-vector\"\n",
    "\n",
    "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
    "\n",
    "mappings = {\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"avg_vector\": {\n",
    "                \"type\": \"dense_vector\",\n",
    "                \"dims\": 128,\n",
    "                \"index\": True,\n",
    "                \"similarity\": \"dot_product\",\n",
    "            },\n",
    "            \"col_pali_vectors\": {\"type\": \"rank_vectors\", \"element_type\": \"bit\"},\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "if not es.indices.exists(index=INDEX_NAME):\n",
    "    print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
    "    es.indices.create(index=INDEX_NAME, body=mappings)\n",
    "else:\n",
    "    print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
    "\n",
    "\n",
    "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
    "    for attempt in range(1, retries + 1):\n",
    "        try:\n",
    "            return es_client.index(index=index, id=doc_id, document=document)\n",
    "        except Exception as e:\n",
    "            if attempt < retries:\n",
    "                wait_time = initial_backoff * (2 ** (attempt - 1))\n",
    "                print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
    "                time.sleep(wait_time)\n",
    "            else:\n",
    "                print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
    "                raise"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "662569b6-77d0-4a0d-8ad6-5d698cf0404c",
   "metadata": {},
   "source": [
    "Here we are looping over all our vectors and convert them into our average vectors.  \n",
    "We still save our full fidelity ColPali vectors as bit vectors as we can use them for reranking. More on that later. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb2e8a4c74b5494c8203308df49ec750",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Indexing documents:   0%|          | 0/500 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Completed indexing 500 documents\n"
     ]
    }
   ],
   "source": [
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "\n",
    "\n",
    "def process_file(file_name, vectors):\n",
    "    if es.exists(index=INDEX_NAME, id=file_name):\n",
    "        return\n",
    "\n",
    "    avg_vector = to_avg_vector(vectors)\n",
    "\n",
    "    bit_vectors = to_bit_vectors(vectors)\n",
    "\n",
    "    index_document(\n",
    "        es_client=es,\n",
    "        index=INDEX_NAME,\n",
    "        doc_id=file_name,\n",
    "        document={\"avg_vector\": avg_vector, \"col_pali_vectors\": bit_vectors},\n",
    "    )\n",
    "\n",
    "\n",
    "with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n",
    "    file_to_multi_vectors = pickle.load(f)\n",
    "\n",
    "with ThreadPoolExecutor(max_workers=10) as executor:\n",
    "    list(\n",
    "        tqdm(\n",
    "            executor.map(\n",
    "                lambda item: process_file(*item), file_to_multi_vectors.items()\n",
    "            ),\n",
    "            total=len(file_to_multi_vectors),\n",
    "            desc=\"Indexing documents\",\n",
    "        )\n",
    "    )\n",
    "\n",
    "print(f\"Completed indexing {len(file_to_multi_vectors)} documents\")"
   ]
  },
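  {
   "cell_type": "markdown",
   "id": "e9c3a2d4-7b5f-4c8a-9d30-4f6b8c0e2a33",
   "metadata": {},
   "source": [
    "To see what ends up in the `col_pali_vectors` field, here is a small sketch that runs `to_bit_vectors` on a made-up 8-dimensional embedding. The values are assumptions for illustration. Each positive component becomes a 1 bit, everything else becomes a 0 bit, and the bits are packed into bytes and hex-encoded. With 128-dimensional ColPali vectors, each token vector shrinks to 16 bytes (32 hex characters). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d4b3e5-8c6a-4d9b-8e41-5a7c9d1f3b44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up 8-dimensional embedding (real ColPali vectors have 128 dimensions).\n",
    "toy_embedding = [0.5, -1.0, 0.2, 0.1, -0.3, -0.2, 0.9, -0.1]\n",
    "\n",
    "# Sign pattern 1 0 1 1 0 0 1 0 -> one byte 0b10110010 -> hex \"b2\"\n",
    "print(to_bit_vectors([toy_embedding]))  # ['b2']"
   ]
  },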
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1dfc3713-d649-46db-aa81-171d6d92668e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdf787df90d24a289fa70e6abd0d9597",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from colpali_engine.models import ColPali, ColPaliProcessor\n",
    "\n",
    "model_name = \"vidore/colpali-v1.3\"\n",
    "model = ColPali.from_pretrained(\n",
    "    \"vidore/colpali-v1.3\",\n",
    "    torch_dtype=torch.float32,\n",
    "    device_map=\"mps\",  # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
    ").eval()\n",
    "\n",
    "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
    "\n",
    "\n",
    "def create_col_pali_query_vectors(query: str) -> list:\n",
    "    queries = col_pali_processor.process_queries([query]).to(model.device)\n",
    "    with torch.no_grad():\n",
    "        return model(**queries).tolist()[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae89aa6f-7022-4ee7-b930-7e1e94c9abfc",
   "metadata": {},
   "source": [
    "Again we create our query vector by using the ColPali model and calculating the average vector. We can now use Elasticsearches `knn` query to compare our single vector to the average vectors within our index. \n",
    "Notice that the document that we have found in our previous examples with the title *Social Media for Recruitment* is now in 5th position. See the next cell on how we can handle this. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_250.jpg\" alt=\"image_250.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display, HTML\n",
    "import os\n",
    "import json\n",
    "\n",
    "DOCUMENT_DIR = \"searchlabs-colpali\"\n",
    "\n",
    "query = \"What do companies use for recruiting?\"\n",
    "query_vector = to_avg_vector(create_col_pali_query_vectors(query))\n",
    "es_query = {\n",
    "    \"_source\": False,\n",
    "    \"knn\": {\n",
    "        \"field\": \"avg_vector\",\n",
    "        \"query_vector\": query_vector,\n",
    "        \"k\": 10,\n",
    "        \"num_candidates\": 100,\n",
    "    },\n",
    "    \"size\": 5,\n",
    "}\n",
    "\n",
    "results = es.search(index=INDEX_NAME, body=es_query)\n",
    "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
    "\n",
    "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
    "for image_id in image_ids:\n",
    "    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
    "    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
    "html += \"</div>\"\n",
    "\n",
    "display(HTML(html))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cf39408-c1b3-4832-a2a6-4c53fb4575d7",
   "metadata": {},
   "source": [
    "In the cell above we have seen that the document *Social Media for Recruitment* is not considered the most relevant anymore.  \n",
    "\n",
    "What we want to do to handle this is to run our KNN vector stage as a **first stage retrieval** algorithm. This is the `knn retriever` in the example below. \n",
    "We then run a **first stage retrieval** rescoring algorithm on a smaller set of these documents. For this we can use higher fidelity ColPali vectors. This is the `rescore` section in the example below. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b04f7fb8-a787-42bc-90a5-c20282cefc0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_250.jpg\" alt=\"image_250.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "query = \"What do companies use for recruiting?\"\n",
    "col_pali_vector = create_col_pali_query_vectors(query)\n",
    "avg_vector = to_avg_vector(col_pali_vector)\n",
    "es_query = {\n",
    "    \"_source\": False,\n",
    "    \"retriever\": {\n",
    "        \"rescorer\": {\n",
    "            \"retriever\": {\n",
    "                \"knn\": {\n",
    "                    \"field\": \"avg_vector\",\n",
    "                    \"query_vector\": avg_vector,\n",
    "                    \"k\": 10,\n",
    "                    \"num_candidates\": 100,\n",
    "                }\n",
    "            },\n",
    "            \"rescore\": {\n",
    "                \"window_size\": 10,\n",
    "                \"query\": {\n",
    "                    \"rescore_query\": {\n",
    "                        \"script_score\": {\n",
    "                            \"query\": {\"match_all\": {}},\n",
    "                            \"script\": {\n",
    "                                \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
    "                                \"params\": {\"query_vector\": col_pali_vector},\n",
    "                            },\n",
    "                        }\n",
    "                    }\n",
    "                },\n",
    "            },\n",
    "        }\n",
    "    },\n",
    "    \"size\": 5,\n",
    "}\n",
    "\n",
    "results = es.search(index=INDEX_NAME, body=es_query)\n",
    "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
    "\n",
    "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
    "for image_id in image_ids:\n",
    "    image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
    "    html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
    "html += \"</div>\"\n",
    "\n",
    "display(HTML(html))"
   ]
  },
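  {
   "cell_type": "markdown",
   "id": "a1e5c4f6-9d7b-4e0c-8f52-6b8d0e2a4c55",
   "metadata": {},
   "source": [
    "To illustrate what the rescoring stage computes, here is a minimal NumPy sketch of late interaction MaxSim scoring: for every query token vector, take the maximum dot product against all document token vectors, then sum these maxima. The helper `max_sim_dot_product` and the toy values are hypothetical and only for illustration. The sketch works on plain float vectors and does not reproduce how Elasticsearch's `maxSimDotProduct` handles bit-encoded document vectors. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2f6d5a7-0e8c-4f1d-9a63-7c9e1f3b5d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def max_sim_dot_product(query_vectors, doc_vectors):\n",
    "    # Hypothetical helper for illustration; not part of the Elasticsearch client.\n",
    "    q = np.array(query_vectors)  # shape: (num_query_tokens, dims)\n",
    "    d = np.array(doc_vectors)  # shape: (num_doc_tokens, dims)\n",
    "    # Dot product of every query vector with every document vector.\n",
    "    similarities = q @ d.T\n",
    "    # Best-matching document vector per query vector, summed over the query.\n",
    "    return float(similarities.max(axis=1).sum())\n",
    "\n",
    "\n",
    "# Made-up 2-dimensional token vectors (illustrative values only).\n",
    "toy_query = [[1.0, 0.0], [0.0, 1.0]]\n",
    "toy_doc = [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]]\n",
    "print(max_sim_dot_product(toy_query, toy_doc))"
   ]
  },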
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5347aef4-8ac0-48ec-a2cb-10745ec1f487",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
    "print(\"Shutting down the kernel to free memory...\")\n",
    "import os\n",
    "\n",
    "os._exit(0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dependecy-test-colpali-blog",
   "language": "python",
   "name": "dependecy-test-colpali-blog"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
